The mushroom data come from the UCI Machine Learning Repository:
https://archive.ics.uci.edu/ml/datasets/mushroom
I have saved all of the data into a folder called ‘data’:
mush <- read_csv('data/mushrooms.csv')
## Rows: 8124 Columns: 23
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (22): class, cap-shape, cap-surface, cap-color, odor, gill-attachment, g...
## lgl (1): bruises
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
read_csv guessed that bruises is a logical (boolean) column - we can override that by specifying the column type:
mush <- read_csv('data/mushrooms.csv',
col_types = cols(bruises = col_character()))
glimpse(mush)
## Rows: 8,124
## Columns: 23
## $ class <chr> "p", "e", "e", "p", "e", "e", "e", "e", "p"…
## $ `cap-shape` <chr> "x", "x", "b", "x", "x", "x", "b", "b", "x"…
## $ `cap-surface` <chr> "s", "s", "s", "y", "s", "y", "s", "y", "y"…
## $ `cap-color` <chr> "n", "y", "w", "w", "g", "y", "w", "w", "w"…
## $ bruises <chr> "t", "t", "t", "t", "f", "t", "t", "t", "t"…
## $ odor <chr> "p", "a", "l", "p", "n", "a", "a", "l", "p"…
## $ `gill-attachment` <chr> "f", "f", "f", "f", "f", "f", "f", "f", "f"…
## $ `gill-spacing` <chr> "c", "c", "c", "c", "w", "c", "c", "c", "c"…
## $ `gill-size` <chr> "n", "b", "b", "n", "b", "b", "b", "b", "n"…
## $ `gill-color` <chr> "k", "k", "n", "n", "k", "n", "g", "n", "p"…
## $ `stalk-shape` <chr> "e", "e", "e", "e", "t", "e", "e", "e", "e"…
## $ `stalk-root` <chr> "e", "c", "c", "e", "e", "c", "c", "c", "e"…
## $ `stalk-surface-above-ring` <chr> "s", "s", "s", "s", "s", "s", "s", "s", "s"…
## $ `stalk-surface-below-ring` <chr> "s", "s", "s", "s", "s", "s", "s", "s", "s"…
## $ `stalk-color-above-ring` <chr> "w", "w", "w", "w", "w", "w", "w", "w", "w"…
## $ `stalk-color-below-ring` <chr> "w", "w", "w", "w", "w", "w", "w", "w", "w"…
## $ `veil-type` <chr> "p", "p", "p", "p", "p", "p", "p", "p", "p"…
## $ `veil-color` <chr> "w", "w", "w", "w", "w", "w", "w", "w", "w"…
## $ `ring-number` <chr> "o", "o", "o", "o", "o", "o", "o", "o", "o"…
## $ `ring-type` <chr> "p", "p", "p", "p", "e", "p", "p", "p", "p"…
## $ `spore-print-color` <chr> "k", "n", "n", "k", "n", "k", "k", "n", "k"…
## $ population <chr> "s", "n", "n", "s", "a", "n", "n", "s", "v"…
## $ habitat <chr> "u", "g", "m", "u", "g", "g", "m", "m", "g"…
Check for duplicates:
dim(distinct(mush))
## [1] 8124 23
dim(mush)
## [1] 8124 23
No duplicates here. Check for missing values:
dim(na.omit(mush))
## [1] 8124 23
dim(mush)
## [1] 8124 23
No missing data.
Check the distribution of our ‘class’ feature - this is our target variable: e = edible, p = poisonous.
table(mush$class)
##
## e p
## 4208 3916
Some algorithms we use don’t do well with column names that have
hyphens in them. We can use gsub for pattern replacement to change
hyphens to underscores. (Strictly, - only has a special meaning inside
a regex character class, so the \\ escape here is harmless but not
required.)
colnames(mush) <- gsub('\\-', '_', colnames(mush))
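As an aside, since this is a literal replacement rather than a real pattern match, gsub with fixed = TRUE would also work and sidesteps regex escaping entirely. A quick demonstration on a toy vector of column names:

```r
# Literal (non-regex) replacement: fixed = TRUE treats '-' as
# a plain character, so no escaping is needed.
nm <- c('cap-shape', 'gill-color', 'stalk-surface-above-ring')
gsub('-', '_', nm, fixed = TRUE)
# "cap_shape" "gill_color" "stalk_surface_above_ring"
```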
Here is our objective: is a mushroom edible or not? We want to turn
this into a binary (0/1) variable because that is what many methods
expect as a target. We will use mutate to create a new column called
edible based on the class column:
mush <- mush %>%
mutate(edible = ifelse(class == 'e', 1, 0))
Then check to see if that worked:
table(mush$class, mush$edible)
##
## 0 1
## e 0 4208
## p 3916 0
All of the variables are categorical, so in order to run a PCA on this
data we have to turn it into a set of dummy variables. This will make a
very wide dataset of yes/no (0, 1) indicators for each feature value. We
remove our target variable edible and our original
class values.
mush_d <- fastDummies::dummy_cols(mush) %>%
select_if(is.numeric) %>%
select(-edible, -class_p, -class_e)
head(mush_d)
## # A tibble: 6 × 117
## cap_shape_b cap_shape_c cap_shape_f cap_shape_k cap_shape_s cap_shape_x
## <int> <int> <int> <int> <int> <int>
## 1 0 0 0 0 0 1
## 2 0 0 0 0 0 1
## 3 1 0 0 0 0 0
## 4 0 0 0 0 0 1
## 5 0 0 0 0 0 1
## 6 0 0 0 0 0 1
## # ℹ 111 more variables: cap_surface_f <int>, cap_surface_g <int>,
## # cap_surface_s <int>, cap_surface_y <int>, cap_color_b <int>,
## # cap_color_c <int>, cap_color_e <int>, cap_color_g <int>, cap_color_n <int>,
## # cap_color_p <int>, cap_color_r <int>, cap_color_u <int>, cap_color_w <int>,
## # cap_color_y <int>, bruises_f <int>, bruises_t <int>, odor_a <int>,
## # odor_c <int>, odor_f <int>, odor_l <int>, odor_m <int>, odor_n <int>,
## # odor_p <int>, odor_s <int>, odor_y <int>, gill_attachment_a <int>, …
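If fastDummies isn’t available, base R’s model.matrix can build the same kind of one-hot expansion. A minimal sketch on a toy data frame (not the mushroom data):

```r
# One-hot encoding with base R: disable contrasts so every
# factor level gets its own 0/1 column, like dummy_cols() does.
d <- data.frame(color = factor(c('r', 'g', 'b', 'g')),
                size  = factor(c('s', 'l', 's', 's')))
mm <- model.matrix(~ . - 1, data = d,
                   contrasts.arg = lapply(d, contrasts, contrasts = FALSE))
colnames(mm)
# "colorb" "colorg" "colorr" "sizel" "sizes"
```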
Lots of different dimensions. The default variable plot is an explosion of arrows, and nothing is dominating.
mush_pca <- FactoMineR::PCA(mush_d, ncp = 20)
library(factoextra)
## Welcome! Want to learn more? See two factoextra-related books at https://goo.gl/ve3WBa
fviz_eig(mush_pca, ncp = 20)
We can plot dim1 and dim2. Wow, some interesting structures here. It looks like there are some weird shadows happening - like each mushroom has a set of doppelgangers that are just slightly different.
fviz_pca_ind(mush_pca, geom = 'point', alpha = 0.2)
We can have a better look - make a dataframe out of our new
coordinates and add in our original edible column as a new
column so we can plot it.
mush_pca_d <- mush_pca$ind$coord %>%
as.data.frame()
mush_pca_d$edible <- as.factor(mush$edible)
We definitely see some interesting clusters:
ggplot(mush_pca_d) +
geom_point(aes(x = Dim.1, y = Dim.2, color = edible), alpha = 0.3) +
scale_color_manual(values = c('firebrick','dodgerblue4')) +
theme_minimal()
Calculate the proportions of 1 and 0:
table(mush$edible) / length(mush$edible)
##
## 0 1
## 0.4820286 0.5179714
Then you can calculate the entropy (from our lecture):
entropy = −(p_1 log2(p_1) + p_2 log2(p_2) + …)
x <- table(mush$edible) / length(mush$edible)
-( x[1] * log2(x[1]) + x[2] * log2(x[2]))
## 0
## 0.9990679
Let’s create a reusable function for calculating entropy:
entropy <- function(x) {
x <- table(x) / length(x)
-1 * sum(sapply(x, function(q) q * log2(q)))
}
entropy(mush$edible)
## [1] 0.9990679
entropy(mush$class)
## [1] 0.9990679
entropy(mush$gill_color)
## [1] 3.030433
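As a sanity check on the entropy() function (repeated here so the chunk runs on its own): a pure vector has 0 bits of uncertainty, a 50/50 split has exactly 1 bit, and four equally likely values have 2 bits.

```r
# entropy() as defined above, repeated so this chunk is standalone.
entropy <- function(x) {
  x <- table(x) / length(x)
  -1 * sum(sapply(x, function(q) q * log2(q)))
}

entropy(rep('a', 10))            # 0: no uncertainty
entropy(c('a', 'b'))             # 1: a fair coin flip
entropy(c('a', 'b', 'c', 'd'))   # 2: log2 of 4 equal outcomes
```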
Plot by gill colour
Here we are building a function, which we call
entropy_plot(). The function is designed to visualize how
the entropy of the target variable (class) varies across
the levels of a categorical variable (v). Entropy is a
measure of uncertainty or randomness, so higher entropy values indicate
greater variability in the class variable within a given
level of v.
Input:
v: a string naming a categorical variable in the mush
dataset. By default, it is set to 'gill_color'.
Output:
a plot of the entropy of class across the levels of the
specified variable v.
entropy_plot <- function(v = 'gill_color') {
d <- mush
d$f <- d[[v]]
d <- d %>%
group_by(f) %>%
summarize(entropy = entropy(class), n = n()) %>%
mutate(p = n / sum(n)) %>%
arrange(p, entropy) %>%
mutate(f = fct_reorder(f, entropy))
print(d)
d <- d %>%
mutate(f_num = as.numeric(f)) %>%
arrange(f_num) %>%
mutate(lag_p = lag(p)) %>%
mutate(lag_p = ifelse(is.na(lag_p), 0, lag_p)) %>%
mutate(x_min = cumsum(lag_p)) %>%
mutate(x_max = lead(x_min)) %>%
mutate(x_max = ifelse(is.na(x_max), 1, x_max))
ggplot(d) +
geom_rect(ymin = 0, aes(xmin = x_min, xmax = x_max, ymax = entropy),
color = 'grey50', fill = 'lightblue') +
geom_text(y = 0, vjust = -1,
aes(x = ((x_max - x_min) / 2) + x_min, label = f)) +
scale_x_continuous(name = v, breaks = c(0, 1), limits = c(0, 1)) +
scale_y_continuous(breaks = c(0, 1), limits = c(0, 1)) +
theme_minimal() +
theme(axis.ticks = element_blank(), panel.grid.minor = element_blank())
}
We can use gill_color as an argument or leave it blank,
as it was set as the default.
If you call entropy_plot('gill_color'), the function
will:
Calculate the entropy of class for each level of
gill_color.
Create a plot where:
The x-axis represents the proportion of observations for each
gill_color.
The y-axis represents the entropy of class for each
gill_color.
Each rectangle corresponds to a level of gill_color,
with its height representing entropy and its width representing
proportion.
entropy_plot('gill_color')
## # A tibble: 12 × 4
## f entropy n p
## <fct> <dbl> <int> <dbl>
## 1 r 0 24 0.00295
## 2 o 0 64 0.00788
## 3 y 0.820 86 0.0106
## 4 e 0 96 0.0118
## 5 k 0.627 408 0.0502
## 6 u 0.461 492 0.0606
## 7 h 0.854 732 0.0901
## 8 g 0.915 752 0.0926
## 9 n 0.490 1048 0.129
## 10 w 0.731 1202 0.148
## 11 p 0.985 1492 0.184
## 12 b 0 1728 0.213
Let’s split. We can use our function on spore print colour to create a
visualisation that shows the entropy of our target variable
(class) across different levels of spore print colour.
entropy_plot('spore_print_color')
## # A tibble: 9 × 4
## f entropy n p
## <fct> <dbl> <int> <dbl>
## 1 b 0 48 0.00591
## 2 o 0 48 0.00591
## 3 u 0 48 0.00591
## 4 y 0 48 0.00591
## 5 r 0 72 0.00886
## 6 h 0.191 1632 0.201
## 7 k 0.528 1872 0.230
## 8 n 0.511 1968 0.242
## 9 w 0.797 2388 0.294
Let’s split on smell
entropy_plot('odor')
## # A tibble: 9 × 4
## f entropy n p
## <fct> <dbl> <int> <dbl>
## 1 m 0 36 0.00443
## 2 c 0 192 0.0236
## 3 p 0 256 0.0315
## 4 a 0 400 0.0492
## 5 l 0 400 0.0492
## 6 s 0 576 0.0709
## 7 y 0 576 0.0709
## 8 f 0 2160 0.266
## 9 n 0.214 3528 0.434
library(rpart)
library(rpart.plot)
Make our first tree and plot it - we don’t need the dummy coding for this:
mush_tree <- rpart(class ~ ., data = mush)
mush_tree
## n= 8124
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 8124 3916 e (0.5179714 0.4820286)
## 2) edible>=0.5 4208 0 e (1.0000000 0.0000000) *
## 3) edible< 0.5 3916 0 p (0.0000000 1.0000000) *
rpart.plot(mush_tree)
Um, we left the edible variable in there. Our tree found
a perfect split with this one variable! This is called “leakage”:
edible is just a recoding of class, the target we’re trying to predict.
Let’s remove it and try again.
mush_tree <- rpart(class ~ ., data = select(mush, -edible),
model = T)
rpart.plot(mush_tree)
Variable importance - information gain.
mush_tree$variable.importance
## odor spore_print_color gill_color
## 3823.407 2833.749 2322.460
## stalk_surface_above_ring stalk_surface_below_ring ring_type
## 2034.584 2030.555 2026.526
Maybe you’re not so good at discerning smells. Try removing
odor from the data and running it again. Overwrite the
mush_tree model so we can explore it below.
# your code here
mush_tree <- rpart(class ~ ., data = select(mush, -edible, -odor),
model = T)
rpart.plot(mush_tree)
A plot of the importance of each variable based on the information it provides.
x <- tibble(var_name = names(mush_tree$variable.importance),
importance = mush_tree$variable.importance) %>%
mutate(var_name = fct_reorder(var_name, importance))
ggplot(x, aes(y = var_name, x = importance)) +
geom_vline(xintercept = 0) +
geom_point() +
theme_minimal()
Restrict the depth of the tree. Try changing the
maxdepth parameter and see how it changes the tree.
mush_tree <- rpart(class ~ ., data = select(mush, -edible, -odor),
model = T, maxdepth=3)
rpart.plot(mush_tree)
Read in the data from the worksheet.
loan <- read_csv('data/week_05_age_balance.csv')
## Rows: 32 Columns: 3
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (1): loan
## dbl (2): age, balance
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
glimpse(loan)
## Rows: 32
## Columns: 3
## $ age <dbl> 43.10638, 29.41489, 23.00000, 25.96809, 31.23404, 32.76596, 36…
## $ balance <dbl> 0.000000, 2.505695, 19.362187, 37.585421, 21.867882, 34.168565…
## $ loan <chr> "no", "no", "no", "no", "no", "no", "no", "no", "no", "no", "n…
Create a simple scatter plot of the data. See how the domains of the two groups might be divided.
ggplot(loan) +
geom_point(aes(x = balance, y = age, shape = loan, color = loan), size = 5) +
theme_minimal()
Play around with the slope and intercept
until the line separates the classes well.
ggplot(loan) +
geom_point(aes(x = balance, y = age, shape = loan, color = loan), size = 5) +
geom_abline(slope = -.32, intercept = 60,
linetype = 2, size = 1) +
theme_minimal() +
scale_x_continuous(labels = scales::dollar) +
labs(title = "Can safe loans and risky loans be separated?") +
xlab("Bank balance ($K)") + ylab("Age")
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
Let’s try something we know, a decision tree!
loan_tree <- rpart(loan ~ balance + age, data = loan)
rpart.plot(loan_tree)
Make a single prediction
predict(loan_tree, newdata = data.frame(balance = 950, age = 20))
## no yes
## 1 0.07142857 0.9285714
Visualize the results
test_grid <- expand.grid(balance = seq(min(loan$balance), max(loan$balance), length.out = 20),
age = seq(min(loan$age), max(loan$age), length.out = 20)) %>%
as_tibble
test_grid$pred <- predict(loan_tree, newdata = test_grid)[,'yes']
test_grid_plot <- function(title = '') {
ggplot(loan) +
geom_tile(data = test_grid, aes(x = balance, y = age, fill = pred)) +
geom_point(aes(x = balance, y = age, shape = loan, color = loan), size = 5) +
scale_fill_gradient(low = 'white', high = '#4D9221') +
scale_color_manual(values = c('blue', 'black')) +
labs(title = title) +
theme_minimal()
}
test_grid_plot('Decision Tree (one split)')
minsplit is the minimum number of cases that must be
in a node before a split is attempted. This is a crucial part of
tuning a decision tree: higher minsplit = simpler model; lower
minsplit = more complex trees.
loan_tree <- rpart(loan ~ balance + age, data = loan,
control = rpart.control(minsplit = 4))
rpart.plot(loan_tree)
test_grid$pred <- predict(loan_tree, newdata = test_grid)[,'yes']
test_grid_plot('Decision Tree (three splits)')
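The effect of minsplit is easy to see on synthetic data: holding everything else constant, lowering minsplit can only allow more splits, never fewer. A sketch (toy data, not the loan dataset):

```r
library(rpart)

# Noisy binary outcome so a perfect single split isn't available.
set.seed(7)
d <- data.frame(x = runif(200))
d$y <- factor(ifelse(d$x + rnorm(200, sd = 0.25) > 0.5, 'yes', 'no'))

simple  <- rpart(y ~ x, data = d)  # default minsplit = 20
complex <- rpart(y ~ x, data = d,
                 control = rpart.control(minsplit = 2))

# Each row of $frame is a node; the lower-minsplit tree
# ends up with at least as many nodes.
c(nrow(simple$frame), nrow(complex$frame))
```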
Make a single prediction
predict(loan_tree, newdata = data.frame(balance = 100, age = 200))
## no yes
## 1 0.07142857 0.9285714
Something we will use often throughout the term is a random forest model - an ensemble of decision trees fit on random bootstraps of the data and subsets of the variables.
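The “random bootstraps” part is ordinary sampling with replacement: each tree sees a resampled copy of the rows, and on average about a third of the rows never appear in a given tree’s sample (these “out-of-bag” rows give the forest a built-in validation set). The resampling step alone, in base R:

```r
# One bootstrap sample of row indices: n draws with replacement,
# so some rows repeat and roughly exp(-1) ~ 36.8% never appear.
set.seed(1)
n <- 10000
idx <- sample(n, replace = TRUE)
mean(!(1:n %in% idx))  # proportion "out of bag"
```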
library(randomForest)
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
loan_rf <- randomForest::randomForest(as.factor(loan) ~ balance + age, data = loan)
test_grid$pred <- predict(loan_rf, newdata = test_grid, type = 'prob')[,2]
test_grid_plot('Random Forest')
A logistic regression is effectively a “one-node neural network”. It takes a linear combination of the inputs and produces an output. It is fit by maximising the likelihood rather than by minimising entropy.
d <- loan %>%
mutate(yes_loan = ifelse(loan == 'yes', 1, 0))
loan_log <- glm(yes_loan ~ balance + age, data = d, family = 'binomial')
test_grid$pred <- predict(loan_log, newdata = test_grid)
test_grid$pred <- exp(test_grid$pred) / (1+ exp(test_grid$pred))
test_grid_plot('Logistic Regression')
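The exp(x) / (1 + exp(x)) step above is the inverse logit. glm can apply it for us with predict(..., type = 'response'), and base R’s plogis() is the same function - a quick check on synthetic data (not the loan dataset):

```r
# Three equivalent ways to get predicted probabilities
# from a logistic regression.
set.seed(1)
d <- data.frame(x = rnorm(100))
d$y <- rbinom(100, 1, plogis(2 * d$x))

m <- glm(y ~ x, data = d, family = 'binomial')

lp       <- predict(m, newdata = d)                    # linear predictor (log-odds)
p_manual <- exp(lp) / (1 + exp(lp))                    # manual inverse logit
p_resp   <- predict(m, newdata = d, type = 'response') # glm's own conversion

all.equal(p_manual, p_resp)  # TRUE; also equal to plogis(lp)
```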
Support Vector Machines - very good for large feature spaces, e.g. textual data.
loan_svm <- e1071::svm(as.factor(loan) ~ balance + age, data = loan, probability = T)
test_grid$pred <- predict(loan_svm, newdata = test_grid, probability = T) %>%
{attr(., 'probabilities')[,2]}
test_grid_plot('Support Vector Machines')
We can stop here for today ————————————
Read in the data for the curve fitting. This really doesn’t represent
anything. Our goal is to predict v2 using v1
as well as we can.
rm(x)
xkcd <- read_csv('data/curve_fitting.csv')
## Rows: 31 Columns: 2
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## dbl (2): v1, v2
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
Quick scatterplot.
ggplot(xkcd, aes(x = v1, y = v2)) +
geom_point() +
theme_classic()
Ordinary least squares optimization. This is the squared error function.
squared_error <- function(param) {
e <- xkcd$v1 * param[1] + param[2]
sum((e - xkcd$v2)^2)
}
Find a line that minimizes the squared error. Change the two numbers in the first line - they represent the slope and intercept.
param <- c(0.8, 0)
ggplot(xkcd) +
geom_abline(slope = param[1], intercept = param[2]) +
geom_point(aes(x = v1, y = v2)) +
annotate('text', x = 0, y = 80,
label = paste('Error: ', round(squared_error(param), 1)),
hjust = 0) +
theme_classic()
An optimizer will explore the combination of parameters and find the
one that minimizes or maximizes a function that you give it. In this
case, we are minimizing squared_error using the two
parameters of slope and intercept.
o <- optim(c(a = 1, b= 1), squared_error, control = list(trace = T))
## Nelder-Mead direct search function minimizer
## function value for initial parameters = 29863.252178
## Scaled convergence tolerance is 0.000444997
## Stepsize computed as 0.100000
## BUILD 3 39568.603649 29863.252178
## EXTENSION 5 29960.027011 18101.040404
## EXTENSION 7 29863.252178 15960.003458
## LO-REDUCTION 9 18101.040404 15960.003458
## HI-REDUCTION 11 16349.042861 15960.003458
## HI-REDUCTION 13 16008.632987 15928.973612
## HI-REDUCTION 15 15960.003458 15810.009918
## REFLECTION 17 15928.973612 15799.398399
## HI-REDUCTION 19 15825.142800 15799.398399
## HI-REDUCTION 21 15810.009918 15799.398399
## EXTENSION 23 15804.441083 15782.536635
## EXTENSION 25 15799.398399 15769.673753
## EXTENSION 27 15782.536635 15718.654324
## EXTENSION 29 15769.673753 15694.035999
## EXTENSION 31 15718.654324 15545.459542
## LO-REDUCTION 33 15694.035999 15545.459542
## EXTENSION 35 15553.626672 15269.398676
## LO-REDUCTION 37 15545.459542 15269.398676
## EXTENSION 39 15378.402908 14994.052425
## EXTENSION 41 15269.398676 14726.236377
## REFLECTION 43 14994.052425 14675.791793
## REFLECTION 45 14726.236377 14572.554056
## HI-REDUCTION 47 14675.791793 14572.554056
## REFLECTION 49 14634.963160 14546.466524
## LO-REDUCTION 51 14572.554056 14546.466524
## LO-REDUCTION 53 14560.392106 14546.466524
## LO-REDUCTION 55 14547.554491 14546.466524
## HI-REDUCTION 57 14546.550985 14544.984393
## HI-REDUCTION 59 14546.466524 14544.340286
## LO-REDUCTION 61 14544.984393 14544.340286
## HI-REDUCTION 63 14544.912978 14544.340286
## LO-REDUCTION 65 14544.520412 14544.303190
## LO-REDUCTION 67 14544.340286 14544.303190
## HI-REDUCTION 69 14544.307992 14544.291244
## HI-REDUCTION 71 14544.303190 14544.280466
## HI-REDUCTION 73 14544.291244 14544.278246
## LO-REDUCTION 75 14544.280466 14544.276430
## HI-REDUCTION 77 14544.278246 14544.275924
## LO-REDUCTION 79 14544.276430 14544.275924
## HI-REDUCTION 81 14544.276239 14544.275695
## Exiting from Nelder Mead minimizer
## 83 function evaluations used
And here is the line that it found. Did you get a line that was close?
param <- o$par
ggplot(xkcd) +
geom_abline(slope = param[1], intercept = param[2]) +
geom_point(aes(x = v1, y = v2)) +
annotate('text', x = 0, y = 80,
label = paste('Error: ', round(squared_error(param), 1)),
hjust = 0) +
theme_classic()
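As a cross-check, the optimizer should land close to what lm() finds, since lm solves the least-squares problem analytically. A sketch on synthetic data (not the curve-fitting file):

```r
# Nelder-Mead on a squared-error function recovers (approximately)
# the analytic least-squares slope and intercept from lm().
set.seed(42)
d <- data.frame(v1 = runif(40, 0, 100))
d$v2 <- 0.5 * d$v1 + 10 + rnorm(40, sd = 5)

sq_err <- function(param) sum((d$v1 * param[1] + param[2] - d$v2)^2)
o <- optim(c(a = 1, b = 1), sq_err)

o$par                        # optimizer's slope and intercept
coef(lm(v2 ~ v1, data = d))  # lm's intercept and slope
```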
ggplot2’s geom_smooth uses a LOESS regression by default.
ggplot(xkcd, aes(x = v1, y = v2)) +
geom_point() +
theme_classic() +
geom_smooth()
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'
You have to specify method = 'lm' if you want an ordinary linear regression.
ggplot(xkcd, aes(x = v1, y = v2)) +
geom_point() +
theme_classic() +
geom_smooth(method = 'lm')
## `geom_smooth()` using formula = 'y ~ x'
To reproduce what geom_smooth does we can fit the linear
regression first, pull the predictions, then plot a line and a
ribbon.
xkcd_lm <- lm(v2 ~ v1, data = xkcd)
x1 <- cbind(xkcd, predict(xkcd_lm, interval = 'confidence'))
ggplot(x1) +
geom_ribbon(aes(x = v1, ymin = lwr, ymax = upr),
fill = 'darkorchid1', alpha = 0.2) +
geom_line(aes(x = v1, y = fit),
color = 'darkorchid4', size = 1) +
geom_point(aes(x = v1, y = v2)) +
theme_classic()
Polynomial regression
f <- lm(v2 ~ v1 + I(v1^2), data = xkcd)
x1 <- cbind(xkcd, predict(f, interval = 'confidence'))
ggplot(x1) +
geom_ribbon(aes(x = v1, ymin = lwr, ymax = upr), alpha = 0.2) +
geom_line(aes(x = v1, y = fit), color = 'blue', size = 1) +
geom_point(aes(x = v1, y = v2)) +
theme_classic()
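poly() builds orthogonal polynomials, so its coefficients look different from the raw I(v1^2) version, but the fitted curve is identical. A quick check on synthetic data:

```r
# Orthogonal (poly) vs raw polynomial terms: different
# coefficients, same fitted values.
set.seed(3)
d <- data.frame(x = runif(40, 0, 10))
d$y <- 2 + 0.5 * d$x - 0.1 * d$x^2 + rnorm(40)

m_raw  <- lm(y ~ x + I(x^2), data = d)
m_poly <- lm(y ~ poly(x, 2), data = d)

all.equal(fitted(m_raw), fitted(m_poly))  # TRUE
```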
A 10-degree polynomial regression. (poly(v1, 10) already includes the
linear term, so we don’t need to add v1 separately.)
f <- lm(v2 ~ poly(v1, 10), data = xkcd)
x1 <- cbind(xkcd, predict(f, interval = 'confidence'))
ggplot(x1) +
geom_ribbon(aes(x = v1, ymin = lwr, ymax = upr), alpha = 0.2) +
geom_line(aes(x = v1, y = fit), color = 'blue', size = 1) +
geom_point(aes(x = v1, y = v2)) +
theme_classic()
Polynomial regressions get crazy outside of their boundaries.
x1 <- data.frame(v1 = seq(-10, 110, length.out = 200))
x1$pred <- predict(f, newdata = x1)
ggplot(xkcd) +
geom_line(data = x1, aes(x = v1, y = pred), color = 'blue', size = 1) +
geom_point(data = xkcd, aes(x = v1, y = v2)) +
theme_classic()
We can use decision trees for regression! It just breaks the curve up into chunks and computes the mean of that chunk. In this case two chunks.
xkcd_tree <- rpart(v2 ~ v1, data = xkcd)
rpart.plot(xkcd_tree)
And it always looks jagged when you plot it.
x1 <- cbind(xkcd, f = predict(xkcd_tree))
ggplot(x1) +
geom_line(aes(x = v1, y = f), color = 'blue', size = 2) +
geom_point(aes(x = v1, y = v2)) +
theme_classic()
The mean squared error of the tree beats the ordinary least squares fit on its own metric of squared error.
mean_squared_error_dt <- function(f) {
x1 <- cbind(xkcd, f = predict(f))
x1 %>%
mutate(resid = v2 - f) %>%
mutate(resid2 = resid^2) %>%
summarize(mse = mean(resid2)) %>%
pull(mse)
}
mean_squared_error_dt(xkcd_tree)
## [1] 384.3234
mean_squared_error_dt(xkcd_lm)
## [1] 469.1702
We can make the decision tree way more complex, and it crushes the squared error of the linear regression.
xkcd_tree <- rpart(v2 ~ v1, data = xkcd,
control = rpart.control(minsplit = 2))
rpart.plot(xkcd_tree)
mean_squared_error_dt(xkcd_tree)
## [1] 70.14077
But is this really the best model? It found some weird parts of the data and fit to them very closely.
x1 <- cbind(xkcd, f = predict(xkcd_tree))
ggplot(x1) +
geom_line(aes(x = v1, y = f), color = 'blue', size = 1) +
geom_point(aes(x = v1, y = v2)) +
theme_classic()
We will think about what makes a good model next time.